import os
import json
import pickle
import time
from collections import OrderedDict, defaultdict

import numpy as np
from glob import glob
from functools import partial
import nibabel as nib
from tqdm import tqdm
import argparse
import SimpleITK as sitk

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchio as tio

from thop import profile
from torchinfo import summary

from segment_anything.build_sam3D import sam_model_registry3D  #
from segment_anything.utils.transforms3D import ResizeLongestSide3D
from segment_anything import sam_model_registry
from segment_anything.modeling.image_encoder3D import ImageEncoderViT3D  #

from utils.click_method import get_next_click3D_torch_ritm, get_next_click3D_torch_2
from utils.data_loader import Dataset_Union_ALL_Val

# from fvcore.nn import FlopCountAnalysis
# from ptflops import get_model_complexity_info


parser = argparse.ArgumentParser()
parser.add_argument('-tdp', '--test_data_path', type=str,
                    default='/content/drive/MyDrive/paper_visual_results/totalseg')
parser.add_argument('-vp', '--vis_path', type=str,
                    default='/content/drive/MyDrive/paper_visual_results/totalseg0441')
parser.add_argument('-cp', '--checkpoint_path', type=str,
                    default="work_dir/finetune/sam_model_latest.pth")  # Teacher model ckpt download from provided ckpt link
parser.add_argument('-tp', '--tiny_vit_checkpoint', type=str, default='ckpt/6layers6heads_student',
                    help='Path to the image encoder checkpoint')  # FastSAM3D model ckpt download from provided ckpt link
# The results file is opened for writing, so the default must name a file (no trailing '/').
parser.add_argument('-sn', '--save_name', type=str,
                    default='/content/drive/MyDrive/paper_visual_results/totalseg0041.py')

parser.add_argument('--image_size', type=int, default=128)
parser.add_argument('--crop_size', type=int, default=128)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('-mt', '--model_type', type=str, default='vit_b_ori')
parser.add_argument('-nc', '--num_clicks', type=int, default=10,
                    help='Number of simulated clicks per volume')
parser.add_argument('-pm', '--point_method', type=str, default='default',
                    help="Click simulation strategy: 'default', 'ritm' or 'random'")
parser.add_argument('-dt', '--data_type', type=str, default='Tr')

parser.add_argument('--threshold', type=int, default=0)
parser.add_argument('--dim', type=int, default=3, help='2 for slice-wise 2D SAM, 3 for 3D SAM')
parser.add_argument('--split_idx', type=int, default=0)
parser.add_argument('--split_num', type=int, default=1)
parser.add_argument('--ft2d', action='store_true', default=False)
parser.add_argument('--seed', type=int, default=2023)

args = parser.parse_args()

# Seed the RNGs used for click simulation so evaluation runs are reproducible.
SEED = args.seed
print("set seed as", SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialise PyTorch's CUDA state up front when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.init()

# Maps the --point_method value to the function that simulates the next click.
click_methods = dict(
    default=get_next_click3D_torch_ritm,
    ritm=get_next_click3D_torch_ritm,
    random=get_next_click3D_torch_2,
)


def save_preprocessed_image(image3D_tensor, img_name, save_directory):
    """Write a volume tensor to ``save_directory`` as a NIfTI file.

    The tensor is moved to CPU and squeezed of singleton dimensions. It is
    then saved with an identity affine, so no spacing or orientation
    information is kept. The file name is ``img_name`` with ``.nii.gz``
    replaced by ``_preprocessed.nii.gz``. The directory is created if needed.
    """
    os.makedirs(save_directory, exist_ok=True)
    volume = np.squeeze(image3D_tensor.cpu().numpy())
    nifti = nib.Nifti1Image(volume, affine=np.eye(4))
    out_name = img_name.replace('.nii.gz', '_preprocessed.nii.gz')
    nib.save(nifti, os.path.join(save_directory, out_name))


def compute_iou(pred_mask, gt_semantic_seg):
    """Return the intersection-over-union of two masks.

    Non-zero elements count as foreground in both inputs. When both masks are
    empty the union is zero, and NaN is returned explicitly (matching
    ``compute_dice``) rather than via a 0/0 divide-by-zero warning.
    """
    intersection = np.logical_and(gt_semantic_seg, pred_mask).sum()
    union = np.logical_or(gt_semantic_seg, pred_mask).sum()
    if union == 0:
        return np.nan
    return intersection / union


def compute_dice(mask_gt, mask_pred):
    """Compute the Soerensen-Dice coefficient of two masks.

    Non-zero elements count as foreground (as in ``compute_iou``). Casting to
    bool also keeps label values > 1 from being mishandled by bitwise ``&``.

    Returns:
        The Dice coefficient as a float. If both masks are empty, the result is NaN.
    """
    mask_gt = np.asarray(mask_gt, dtype=bool)
    mask_pred = np.asarray(mask_pred, dtype=bool)
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        # Use np.nan: the np.NaN alias was removed in NumPy 2.0.
        return np.nan
    volume_intersect = np.logical_and(mask_gt, mask_pred).sum()
    return 2 * volume_intersect / volume_sum


def postprocess_masks(low_res_masks, image_size, original_size):
    """Bring low-resolution mask logits back to the original slice size.

    The logits are first bilinearly resized to an ``image_size`` square. Two
    cases follow:

    * The global ``--ft2d`` flag is set and the original slice is smaller than
      that square in both dimensions: the centred ``original_size`` window is
      cropped out and its ``(top, left)`` offset returned.
    * Otherwise: the square is bilinearly resized to ``original_size`` and the
      offset is ``None``.

    Returns:
        (masks, pad)
    """
    ori_h, ori_w = original_size
    upsampled = F.interpolate(low_res_masks, (image_size, image_size), mode="bilinear", align_corners=False)
    if not (args.ft2d and ori_h < image_size and ori_w < image_size):
        resized = F.interpolate(upsampled, original_size, mode="bilinear", align_corners=False)
        return resized, None
    top = (image_size - ori_h) // 2
    left = (image_size - ori_w) // 2
    cropped = upsampled[..., top: top + ori_h, left: left + ori_w]
    return cropped, (top, left)


def sam_decoder_inference(target_size, points_coords, points_labels, model, image_embeddings, mask_inputs=None,
                          multimask=False):
    """Run the SAM prompt encoder and mask decoder for a set of point prompts.

    Args:
        target_size: side length the returned ``masks`` are resized to.
        points_coords, points_labels: point prompts; moved to ``model.device``.
        model: SAM model exposing ``prompt_encoder`` and ``mask_decoder``.
        image_embeddings: precomputed image-encoder output.
        mask_inputs: optional previous low-res mask passed to the prompt encoder.
        multimask: if True, the decoder proposes several masks, and for each
            batch item only the one with the highest predicted IoU is kept.

    Returns:
        (masks, low_res_masks, iou_predictions), where ``masks`` is
        ``low_res_masks`` bilinearly resized to (target_size, target_size).
    """
    with torch.no_grad():
        sparse, dense = model.prompt_encoder(
            points=(points_coords.to(model.device), points_labels.to(model.device)),
            boxes=None,
            masks=mask_inputs,
        )
        low_res_masks, iou_predictions = model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=multimask,
        )

    if multimask:
        best_iou, best_idx = torch.max(iou_predictions, dim=1)
        iou_predictions = best_iou.unsqueeze(1)
        # Keep one mask per batch item, preserving a singleton channel dim.
        low_res_masks = torch.stack(
            [low_res_masks[b:b + 1, k] for b, k in enumerate(best_idx)], 0)
    masks = F.interpolate(low_res_masks, (target_size, target_size), mode="bilinear", align_corners=False)
    return masks, low_res_masks, iou_predictions


def repixel_value(arr, is_seg=False):
    """Min-max rescale an intensity array/tensor to the range [0, 255].

    Segmentation masks (``is_seg=True``) are returned unchanged, since
    rescaling label values would corrupt them.
    """
    if is_seg:
        return arr
    min_val = arr.min()
    max_val = arr.max()
    # The epsilon avoids division by zero for constant inputs (they map to 0).
    return (arr - min_val) / (max_val - min_val + 1e-10) * 255.


def random_point_sampling(mask, get_point=1):
    """Randomly sample prompt points from a 2D binary mask.

    Coordinates are returned in (x, y) order, i.e. reversed from the array's
    (row, col) indexing.

    Args:
        mask: 2D array or tensor; value 1 marks foreground, 0 background.
        get_point: number of points to sample.

    Returns:
        (coords, labels): a float tensor of shape (get_point, 2) and an int
        tensor of shape (get_point,), with label 1 for foreground and 0 for
        background.

        * ``get_point == 1``: a foreground point is drawn if any exists,
          otherwise a background point.
        * Otherwise: half the points (rounded down) are foreground and the
          rest background, in random order. If one class is absent, all points
          come from the other.

    Raises:
        ValueError: if points are requested but the mask has no 0 or 1 elements.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    fg_coords = np.argwhere(mask == 1)[:, ::-1]
    bg_coords = np.argwhere(mask == 0)[:, ::-1]

    fg_size = len(fg_coords)
    bg_size = len(bg_coords)
    if get_point > 0 and fg_size == 0 and bg_size == 0:
        raise ValueError("random_point_sampling: mask has no foreground (1) or background (0) elements")

    if get_point == 1:
        if fg_size > 0:
            coord = fg_coords[np.random.randint(fg_size)]
            label = 1
        else:
            coord = bg_coords[np.random.randint(bg_size)]
            label = 0
        return torch.as_tensor([coord.tolist()], dtype=torch.float), torch.as_tensor([label], dtype=torch.int)

    num_fg = get_point // 2
    # If one class is missing, draw every point from the other one instead of
    # letting np.random.choice fail on an empty population.
    if fg_size == 0:
        num_fg = 0
    elif bg_size == 0:
        num_fg = get_point
    num_bg = get_point - num_fg
    fg_indices = np.random.choice(fg_size, size=num_fg, replace=True)
    bg_indices = np.random.choice(bg_size, size=num_bg, replace=True)
    coords = np.concatenate([fg_coords[fg_indices], bg_coords[bg_indices]], axis=0)
    labels = np.concatenate([np.ones(num_fg), np.zeros(num_bg)]).astype(int)
    order = np.random.permutation(get_point)
    return (torch.as_tensor(coords[order], dtype=torch.float),
            torch.as_tensor(labels[order], dtype=torch.int))


def finetune_model_predict2D(img3D, gt3D, sam_model_tune, target_size=256, click_method='random', device='cuda',
                             num_clicks=1, prev_masks=None):
    """Segment a volume slice-by-slice with a 2D SAM model and simulated clicks.

    Slices are taken along the last axis. A slice whose ground truth is empty
    gets an all-zero prediction without running the model. For every other
    slice, ``num_clicks`` clicks are simulated. Each click is:

    1. drawn at random from the current error region (false negatives plus
       false positives),
    2. labelled from the ground truth,
    3. fed to the decoder together with all earlier clicks and the previous
       low-res mask.

    Args:
        img3D: image volume; presumably (B, 1, H, W, D) — TODO confirm. The
            channel axis is repeated 3x to match RGB SAM input.
        gt3D: ground-truth volume in the same layout; metrics use ``gt3D[0][0]``.
        sam_model_tune: 2D SAM model (image_encoder / prompt_encoder / mask_decoder).
        target_size: side length each slice is resized to before encoding.
        click_method: unused; clicks always come from ``random_point_sampling``.
        device: device the model runs on.
        num_clicks: number of clicks simulated per slice.
        prev_masks: never read; it is overwritten by the decoder output.

    Returns:
        (pred_list, click_points, click_labels, iou_list, dice_list):
        ``pred_list``, ``iou_list`` and ``dice_list`` each hold one entry per
        click count, namely the thresholded uint8 volume and its IoU / Dice.
        ``click_points`` and ``click_labels`` hold every click of every slice.
    """
    pred_list = []
    iou_list = []
    dice_list = []

    # slice_mask_list[k] collects, slice by slice, the prediction after k + 1 clicks.
    slice_mask_list = defaultdict(list)

    img3D = torch.repeat_interleave(img3D, repeats=3, dim=1)  # 1 channel -> 3 channel (align to RGB)

    click_points = []
    click_labels = []
    for slice_idx in tqdm(range(img3D.size(-1)), desc="transverse slices", leave=False):
        img2D, gt2D = repixel_value(img3D[..., slice_idx]), gt3D[..., slice_idx]

        if (gt2D == 0).all():
            # Nothing to segment: record an empty slice for every click count.
            empty_result = torch.zeros(list(gt3D.size()[:-1]) + [1]).to(device)
            for click_idx in range(num_clicks):
                slice_mask_list[click_idx].append(empty_result)
            continue

        img2D = F.interpolate(img2D, (target_size, target_size), mode="bilinear", align_corners=False)
        gt2D = F.interpolate(gt2D.float(), (target_size, target_size), mode="nearest").int()

        img2D, gt2D = img2D.to(device), gt2D.to(device)
        img_std = img2D.std()
        # A constant slice has zero std; dividing by it would turn the whole input into NaN.
        if img_std > 0:
            img2D = (img2D - img2D.mean()) / img_std
        else:
            img2D = img2D - img2D.mean()

        with torch.no_grad():
            image_embeddings = sam_model_tune.image_encoder(img2D.float())

        points_co, points_la = torch.zeros(1, 0, 2).to(device), torch.zeros(1, 0).to(device)
        low_res_masks = None
        gt_semantic_seg = gt2D[0, 0].to(device)
        true_masks = (gt_semantic_seg > 0)
        for click_idx in range(num_clicks):
            if low_res_masks is None:
                pred_masks = torch.zeros_like(true_masks).to(device)
            else:
                pred_masks = (prev_masks[0, 0] > 0.0).to(device)
            # Next click is sampled uniformly from the current error region.
            fn_masks = torch.logical_and(true_masks, torch.logical_not(pred_masks))
            fp_masks = torch.logical_and(torch.logical_not(true_masks), pred_masks)
            mask_to_sample = torch.logical_or(fn_masks, fp_masks)
            new_points_co, _ = random_point_sampling(mask_to_sample.cpu(), get_point=1)
            # Coordinates are (x, y), so index the ground truth as [y, x].
            is_foreground = true_masks[new_points_co[0, 1].int(), new_points_co[0, 0].int()]
            new_points_la = torch.Tensor([1]).to(torch.int64) if is_foreground else torch.Tensor([0]).to(torch.int64)
            new_points_co, new_points_la = new_points_co[None].to(device), new_points_la[None].to(device)
            points_co = torch.cat([points_co, new_points_co], dim=1)
            points_la = torch.cat([points_la, new_points_la], dim=1)
            prev_masks, low_res_masks, iou_predictions = sam_decoder_inference(
                target_size, points_co, points_la, sam_model_tune, image_embeddings,
                mask_inputs=low_res_masks, multimask=True)
            click_points.append(new_points_co)
            click_labels.append(new_points_la)

            slice_mask, _ = postprocess_masks(low_res_masks, target_size, (gt3D.size(2), gt3D.size(3)))
            slice_mask_list[click_idx].append(slice_mask[..., None])  # append (B, C, H, W, 1)

    gt_np = gt3D[0][0].detach().cpu().numpy()
    for click_idx in range(num_clicks):
        medsam_seg = torch.cat(slice_mask_list[click_idx], dim=-1).cpu().numpy().squeeze()
        medsam_seg = (medsam_seg > sam_model_tune.mask_threshold).astype(np.uint8)

        pred_list.append(medsam_seg)
        iou_list.append(round(compute_iou(medsam_seg, gt_np), 4))
        dice_list.append(round(compute_dice(gt_np.astype(np.uint8), medsam_seg), 4))

    return pred_list, click_points, click_labels, iou_list, dice_list


def finetune_model_predict3D(tiny_vit, img3D, gt3D, sam_model_tune, device='cuda', click_method='random', num_clicks=10,
                             prev_masks=None):
    """Run interactive 3D segmentation with the distilled encoder and simulated clicks.

    Overall flow:

    1. The volume is z-normalised with the module-level ``norm_transform``
       (defined in the ``__main__`` block).
    2. It is encoded once by ``tiny_vit``.
    3. For each of ``num_clicks`` rounds, a click is simulated from the
       previous prediction, the SAM prompt encoder and mask decoder are run,
       and the prediction is scored against ``gt3D``.

    The first two clicks use ``click_method``; later clicks always use "random".

    Returns:
        pred_list: per-click binary predictions (uint8 numpy arrays).
        click_points, click_labels: per-click prompt tensors.
        iou_list, dice_list: per-click IoU / Dice, rounded to 4 decimals.
        encoder_time: seconds spent in the ``tiny_vit`` forward pass.
        decoder_time: per-click seconds spent in the mask decoder forward pass.
        memory_before: peak CUDA memory (bytes) from function entry through
            the encoder forward pass; 0 when not running on CUDA.
        memory_decoder: peak CUDA memory (bytes) since encoding finished,
            sampled after the last click; 0 when not running on CUDA.
        FLOPS: per-click operation counts reported by ``thop.profile`` for the
            prompt encoder plus the mask decoder.
    """
    # CUDA timing/memory APIs only apply (and only work) on a CUDA device.
    use_cuda = torch.cuda.is_available() and torch.device(device).type == 'cuda'

    def _sync():
        # Kernels launch asynchronously; wait for them so wall-clock timings are real.
        if use_cuda:
            torch.cuda.synchronize(device)

    def _reset_peak_memory():
        if use_cuda:
            torch.cuda.reset_peak_memory_stats(device)

    def _peak_memory():
        return torch.cuda.max_memory_allocated(device) if use_cuda else 0

    _reset_peak_memory()
    decoder_time = []
    img3D = norm_transform(img3D.squeeze(dim=1))  # (N, C, W, H, D)
    img3D = img3D.unsqueeze(dim=1)
    click_points = []
    click_labels = []

    pred_list = []
    iou_list = []
    dice_list = []
    FLOPS = np.zeros(num_clicks)
    if prev_masks is None:
        prev_masks = torch.zeros_like(gt3D).to(device)
    low_res_masks = F.interpolate(prev_masks.float(),
                                  size=(args.crop_size // 4, args.crop_size // 4, args.crop_size // 4))

    # Time only the encoder forward pass; FLOP profiling re-runs it separately below.
    _sync()
    encoder_start = time.perf_counter()
    with torch.no_grad():
        image_embedding = tiny_vit(img3D.to(device))
    _sync()
    encoder_time = time.perf_counter() - encoder_start
    memory_before = _peak_memory()
    # tiny_vit's output is indexed with [-1]; presumably the final-stage embedding — TODO confirm shape.
    image_embedding = image_embedding[-1]
    with torch.no_grad():
        print(profile(tiny_vit, (img3D.to(device),))[0])  # FLOPs for image encoder part
    print(memory_before)
    _reset_peak_memory()
    # Defined up front so the return value exists even when num_clicks == 0.
    memory_decoder = _peak_memory()

    gt_np = gt3D[0][0].detach().cpu().numpy()
    for num_click in range(num_clicks):
        with torch.no_grad():
            if num_click > 1:
                click_method = "random"
            batch_points, batch_labels = click_methods[click_method](prev_masks.to(device), gt3D.to(device))

            points_co = torch.cat(batch_points, dim=0).to(device)
            points_la = torch.cat(batch_labels, dim=0).to(device)

            click_points.append(points_co)
            click_labels.append(points_la)

            mask_input = low_res_masks.to(device)
            sparse_embeddings, dense_embeddings = sam_model_tune.prompt_encoder(
                points=[points_co, points_la],
                boxes=None,
                masks=mask_input,
            )
            FLOPS[num_click] += profile(sam_model_tune.prompt_encoder, ([points_co, points_la], None, mask_input,))[0]

            # Time only the decoder forward pass; its FLOP profiling runs after the clock stops.
            _sync()
            start_time = time.perf_counter()
            low_res_masks, _ = sam_model_tune.mask_decoder(
                image_embeddings=image_embedding.to(device),
                image_pe=sam_model_tune.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
            )
            _sync()
            decoder_time.append(time.perf_counter() - start_time)
            FLOPS[num_click] += profile(sam_model_tune.mask_decoder, (
                image_embedding, sam_model_tune.prompt_encoder.get_dense_pe(), sparse_embeddings, dense_embeddings,
                False,))[0]
            memory_decoder = _peak_memory()

            prev_masks = F.interpolate(low_res_masks, size=gt3D.shape[-3:], mode='trilinear', align_corners=False)

            # Convert logits to a binary mask at probability 0.5.
            medsam_seg_prob = torch.sigmoid(prev_masks).cpu().numpy().squeeze()
            medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
            pred_list.append(medsam_seg)

            iou_list.append(round(compute_iou(medsam_seg, gt_np), 4))
            dice_list.append(round(compute_dice(gt_np.astype(np.uint8), medsam_seg), 4))
    return pred_list, click_points, click_labels, iou_list, dice_list, encoder_time, decoder_time, memory_before, memory_decoder, FLOPS


if __name__ == "__main__":
    st = time.time()

    # --test_data_path may be a glob pattern; only matching directories are kept.
    all_dataset_paths = glob(args.test_data_path)

    print(args.test_data_path)
    all_dataset_paths = list(filter(os.path.isdir, all_dataset_paths))

    print("get", len(all_dataset_paths), "datasets")

    # Reorient to canonical orientation, then crop/pad each volume to a crop_size cube,
    # using the 'label' image to choose where to crop.
    infer_transform = [
        tio.ToCanonical(),
        tio.CropOrPad(mask_name='label', target_shape=(args.crop_size, args.crop_size, args.crop_size)),
    ]

    test_dataset = Dataset_Union_ALL_Val(
        paths=all_dataset_paths,
        mode="Val",
        data_type=args.data_type,
        transform=tio.Compose(infer_transform),
        threshold=args.threshold,
        split_num=args.split_num,
        split_idx=args.split_idx,
        pcc=False,
    )
    # Shuffled, but reproducible because torch.manual_seed is set at import time.
    test_dataloader = DataLoader(
        dataset=test_dataset,
        sampler=None,
        batch_size=1,
        shuffle=True
    )

    checkpoint_path = args.checkpoint_path

    device = args.device
    print("device:", device)

    # NOTE: torch.load unpickles checkpoint files; only load checkpoints from a trusted source.
    if args.dim == 3:
        sam_model_tune = sam_model_registry3D[args.model_type](checkpoint=None).to(device)
        if checkpoint_path is not None:
            model_dict = torch.load(checkpoint_path, map_location=device)
            state_dict = model_dict['model_state_dict']
            sam_model_tune.load_state_dict(state_dict)
    elif args.dim == 2:
        # The 2D registry builder takes the whole args namespace; presumably it reads args.sam_checkpoint.
        args.sam_checkpoint = args.checkpoint_path
        sam_model_tune = sam_model_registry[args.model_type](args).to(device)
    else:
        # Fail early rather than with a NameError on sam_model_tune further down.
        raise ValueError(f"--dim must be 2 or 3, got {args.dim}")

    # Distilled image encoder (6 layers, 6 heads); this configuration must match the checkpoint.
    tiny_vit_checkpoint_path = args.tiny_vit_checkpoint
    tiny_vit = ImageEncoderViT3D(
        depth=6,
        embed_dim=768,
        img_size=128,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=6,
        patch_size=16,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=[2, 5, 8, 11],
        window_size=0,
        out_chans=384,
        skip_layer=2,
    )

    model_dict = torch.load(tiny_vit_checkpoint_path, map_location=args.device)
    state_dict = model_dict['model_state_dict']
    tiny_vit.load_state_dict(state_dict)
    print(f"Loaded weights from {tiny_vit_checkpoint_path}")

    tiny_vit = tiny_vit.to(device)
    all_iou_list = []
    all_dice_list = []

    # Per-sample results and efficiency statistics accumulated by the evaluation loop.
    out_dice = dict()
    out_dice_all = OrderedDict()
    encoder_times = []
    decoder_times = []
    average_decoder_times = []
    memory_befores = []
    memory_decoders = []
    FLOPSS = []
    # finetune_model_predict3D reads this module-level name; build it once rather than per sample.
    norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)
    # Set to True to re-score predictions already saved under vis_root instead of re-running the model.
    reuse_saved_predictions = False
    for batch_data in tqdm(test_dataloader):
        image3D, gt3D, img_name = batch_data

        print(gt3D.shape)

        sz = image3D.size()
        if sz[2] < args.crop_size or sz[3] < args.crop_size or sz[4] < args.crop_size:
            print("[ERROR] wrong size", sz, "for", img_name)
        # Outputs are grouped by the modality / dataset directory names above the image file.
        modality = os.path.basename(os.path.dirname(os.path.dirname(os.path.dirname(img_name[0]))))
        dataset = os.path.basename(os.path.dirname(os.path.dirname(img_name[0])))
        vis_root: str = os.path.join(os.path.dirname(__file__), args.vis_path, modality, dataset)
        gt_save_directory = os.path.join(vis_root, "gt_masks")
        os.makedirs(gt_save_directory, exist_ok=True)

        base_name = os.path.basename(img_name[0])
        click_suffix = f"_pred{args.num_clicks - 1}.nii.gz"
        pred_path = os.path.join(vis_root, base_name.replace(".nii.gz", click_suffix))

        # Efficiency statistics are only produced by the 3D predictor.
        ran_3d = False
        if reuse_saved_predictions and os.path.exists(pred_path):
            iou_list, dice_list = [], []
            gt_np = gt3D[0][0].detach().cpu().numpy()
            for click_idx in range(args.num_clicks):
                curr_pred_path = os.path.join(vis_root, base_name.replace(".nii.gz", f"_pred{click_idx}.nii.gz"))
                medsam_seg = sitk.GetArrayFromImage(sitk.ReadImage(curr_pred_path))
                iou_list.append(round(compute_iou(medsam_seg, gt_np), 4))
                dice_list.append(round(compute_dice(gt_np.astype(np.uint8), medsam_seg), 4))
        else:
            if args.dim == 3:
                (seg_mask_list, points, labels, iou_list, dice_list,
                 encoder_time, decoder_time, memory_before, memory_decoder, FLOPS) = finetune_model_predict3D(
                    tiny_vit,
                    image3D, gt3D, sam_model_tune, device=device,
                    click_method=args.point_method, num_clicks=args.num_clicks,
                    prev_masks=None)
                ran_3d = True
            elif args.dim == 2:
                seg_mask_list, points, labels, iou_list, dice_list = finetune_model_predict2D(
                    image3D, gt3D, sam_model_tune, device=device, target_size=args.image_size,
                    click_method=args.point_method, num_clicks=args.num_clicks,
                    prev_masks=None)
            os.makedirs(vis_root, exist_ok=True)
            points = [p.cpu().numpy() for p in points]
            labels = [lbl.cpu().numpy() for lbl in labels]
            pt_info = dict(points=points, labels=labels)
            print("save to", os.path.join(vis_root, base_name.replace(".nii.gz", "_pred.nii.gz")))
            pt_path = os.path.join(vis_root, base_name.replace(".nii.gz", "_pt.pkl"))
            with open(pt_path, "wb") as pt_file:
                pickle.dump(pt_info, pt_file)
            for idx, pred3D in enumerate(seg_mask_list):
                out = sitk.GetImageFromArray(pred3D)
                sitk.WriteImage(out, os.path.join(vis_root, base_name.replace(".nii.gz", f"_pred{idx}.nii.gz")))

        # Each sample is scored by its best result over all clicks.
        all_iou_list.append(max(iou_list))
        all_dice_list.append(max(dice_list))
        print(dice_list)
        out_dice[img_name] = max(dice_list)
        if ran_3d:
            encoder_times.append(encoder_time)
            average_decoder_times.append(np.average(decoder_time))
            FLOPSS.append(FLOPS)
            decoder_times.append(decoder_time)
            memory_befores.append(memory_before)
            memory_decoders.append(memory_decoder)
        out_dice_all[img_name[0]] = OrderedDict((f'{i}', dice) for i, dice in enumerate(dice_list))

    if all_iou_list:
        print('Mean IoU : ', sum(all_iou_list) / len(all_iou_list))
        print('Mean Dice: ', sum(all_dice_list) / len(all_dice_list))
    else:
        print('No samples were evaluated.')
    if encoder_times:
        # Mean image-encoder time per evaluated volume.
        print(sum(encoder_times) / len(encoder_times))

    # Group per-image dice dicts by the 4th-from-last path component (used as the organ name).
    final_dice_dict = OrderedDict()
    for k, v in out_dice_all.items():
        organ = k.split('/')[-4]
        final_dice_dict.setdefault(organ, OrderedDict())[k] = v

    if args.split_num > 1:
        # Tag the output with the split so parallel shards do not overwrite each other.
        save_root, save_ext = os.path.splitext(args.save_name)
        args.save_name = f'{save_root}_s{args.split_num}i{args.split_idx}{save_ext}'
    # Swap the extension instead of str.replace('.py', ...), which left a non-.py name
    # unchanged and made the JSON dump overwrite the results file.
    json_path = os.path.splitext(args.save_name)[0] + '.json'

    print("Save to", args.save_name)
    with open(args.save_name, 'w') as f:
        f.write(f'# mean dice: \t{np.mean(all_dice_list)}\n')
        f.write('dice_Ts = {')
        for k, v in out_dice.items():
            f.write(f'\'{str(k[0])}\': {v},\n')
        f.write('encoder_time')
        for i in encoder_times:
            f.write(f'\'{str(i)},\n')
        f.write('decode')
        for sample_decoder_times in decoder_times:
            for b in sample_decoder_times:
                f.write(f'\'{str(b)},\n')
        f.write('average decode')
        for j in average_decoder_times:
            f.write(f'\'{str(j)},\n')
        f.write('flops')
        # FLOPSS holds one per-click FLOPs array per sample; write every sample, not just the last.
        for sample_flops in FLOPSS:
            for j in sample_flops:
                f.write(f'\'{str(j)},\n')
        for j in memory_befores:
            f.write(f'\'{str(j)},\n')
        for j in memory_decoders:
            f.write(f'\'{str(j)},\n')
        f.write('}')
    with open(json_path, 'w') as f:
        json.dump(final_dice_dict, f, indent=4)

    print("Done")
    eo = time.time() - st
    print(eo)
